import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as util

import matplotlib.pyplot as pyplot
import numpy as np
import os

# Model / training hyper-parameters.
output_img_size = 28  # MNIST images are 28x28
input_dim = 100       # channel width of the 1x1 label code fed to the generator
channel_num = 1       # MNIST is single-channel (grayscale)
features_num = 64     # base number of deconv feature maps
batch_size = 64

print(f'prepare datasets begin')
# Legacy tensor-type aliases used to place tensors on GPU when available.
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
itype = torch.cuda.LongTensor if use_cuda else torch.LongTensor

# MNIST train/test splits; first run downloads into ./data.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# Split the 10k official test images: first 5000 for validation, the rest for testing.
index_verify = range(len(test_dataset))[:5000]
index_test = range(len(test_dataset))[5000:]

sampler_verify = torch.utils.data.sampler.SubsetRandomSampler(index_verify)
sampler_test = torch.utils.data.sampler.SubsetRandomSampler(index_test)

verify_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, sampler=sampler_verify)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, sampler=sampler_test)

class AntiCNN(nn.Module):
    """Deconvolutional 'generator'.

    Maps a (N, input_dim, 1, 1) code to a (N, channel_num, 28, 28) image,
    squashed to [0, 1] by the final sigmoid (spatial growth: 1 -> 5 -> 13 -> 28
    given the kernel/stride/padding choices below). Child-module names are
    kept identical to the original so state_dicts remain interchangeable.
    """

    def __init__(self):
        super(AntiCNN, self).__init__()
        self.model = nn.Sequential()
        stages = (
            ('deconv1', nn.ConvTranspose2d(input_dim, features_num * 2, 5, 2, 0, bias=False)),
            ('batch_norm1', nn.BatchNorm2d(features_num * 2)),
            ('relu1', nn.ReLU(True)),
            ('deconv2', nn.ConvTranspose2d(features_num * 2, features_num, 5, 2, 0, bias=False)),
            ('batch_norm2', nn.BatchNorm2d(features_num)),
            ('relu2', nn.ReLU(True)),
            ('deconv3', nn.ConvTranspose2d(features_num, channel_num, 4, 2, 0, bias=False)),
            ('sigmoid', nn.Sigmoid()),
        )
        for name, stage in stages:
            self.model.add_module(name, stage)

    def forward(self, input):
        # Calling the Sequential directly is equivalent to the original's
        # manual loop over named_children().
        return self.model(input)
    
def weight_init(module):
    """DCGAN-style weight initializer, meant for ``net.apply(weight_init)``.

    (Transposed-)convolution weights are drawn from N(0, 0.02); batch-norm
    scale parameters from N(1, 0.02). Other modules are left untouched.

    Bug fixed: the original tested for lowercase 'conv'/'norm' substrings,
    which never match PyTorch class names such as 'ConvTranspose2d' or
    'BatchNorm2d', so the initializer silently did nothing.
    """
    class_name = module.__class__.__name__
    if class_name.find('Conv') != -1:
        module.weight.data.normal_(0, 0.02)  # args are (mean, std)
    elif class_name.find('Norm') != -1:
        module.weight.data.normal_(1, 0.02)
        
def resize_to_img(img):
    """Broadcast a (N, 1, H, W) batch to three channels for RGB image saving.

    Generalized from the original, which hard-coded the module-level
    ``batch_size`` and ``output_img_size``: batch and spatial sizes are now
    read from the tensor itself, so partial final batches also work.
    ``detach()`` replaces the deprecated ``.data`` access.
    """
    return img.detach().expand(img.size(0), 3, img.size(-2), img.size(-1))

def imgshow(input, title=None):
    """Render a CPU image tensor with matplotlib, min-max scaled to [0, 1].

    A tensor whose first dimension is > 1 is treated as (C, H, W) and moved
    to (H, W, C) for imshow; otherwise the single channel is shown as 2-D.
    """
    if input.size(0) > 1:
        arr = input.numpy().transpose((1, 2, 0))
    else:
        arr = input[0].numpy()
    lo, hi = np.amin(arr), np.amax(arr)
    if hi > lo:
        # Stretch to the full [0, 1] range; skipped for constant images
        # to avoid a divide-by-zero.
        arr = (arr - lo) / (hi - lo)
    pyplot.imshow(arr)
    if title:
        pyplot.title(title)
    pyplot.pause(0.001)

def main():
    """Train AntiCNN to map digit labels to MNIST images, then save one
    generated batch to disk.

    Fixes over the original:
    - ``net.apply(weight_init)`` is actually called (the initializer was
      defined but never used).
    - Deprecated ``Tensor.resize()`` / ``.data.numpy()`` replaced with
      ``view()`` / ``loss.item()``.
    - The MSE *target* no longer gets ``requires_grad_(True)`` — only the
      network parameters need gradients.
    - Validation and final generation run under ``torch.no_grad()`` so no
      autograd graphs accumulate.
    - Dead commented-out code removed; the saved filename no longer relies
      on the leaked loop variable ``epoch``.
    """
    net = AntiCNN()
    net.apply(weight_init)
    if use_cuda:
        net = net.cuda()
    criterion = nn.MSELoss()
    optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

    # Random digit labels (0-9) that seed the final image dump.
    samples = torch.from_numpy(np.random.choice(10, batch_size)).type(dtype)

    step = 0
    num_epoch = 2
    record = []
    print('train begin')
    for epoch in range(num_epoch):
        print(f'the no.{epoch} epoch')
        train_loss = []
        for batch_index, (images, labels) in enumerate(train_loader):
            # The generator's input is the digit label broadcast to an
            # input_dim-channel 1x1 code; the regression target is the image.
            target = images.clone().detach()
            data = labels.clone().detach()
            if use_cuda:
                target, data = target.cuda(), data.cuda()
            data = data.type(dtype).view(-1, 1, 1, 1).expand(-1, input_dim, 1, 1)

            net.train()
            output = net(data)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            train_loss.append(loss.item())

            if batch_index % 300 == 0:
                net.eval()
                verify_loss = []
                with torch.no_grad():
                    for images_v, labels_v in verify_loader:
                        target_v, data_v = images_v, labels_v
                        if use_cuda:
                            target_v, data_v = target_v.cuda(), data_v.cuda()
                        data_v = data_v.type(dtype).view(-1, 1, 1, 1).expand(-1, input_dim, 1, 1)
                        verify_loss.append(criterion(net(data_v), target_v).item())
                print(f'now no.{batch_index} batch. train loss:{np.mean(train_loss):.4f}, verify loss:{np.mean(verify_loss):.4f}')
                record.append([np.mean(train_loss), np.mean(verify_loss)])

    # Generate one batch of images from the sampled labels.
    samples = samples.view(batch_size, 1, 1, 1).expand(batch_size, input_dim, 1, 1)
    if use_cuda:
        samples = samples.cuda()
    with torch.no_grad():
        fake_u = net(samples)
    img = resize_to_img(fake_u)
    out_dir = os.path.realpath('./pytorch/jizhi/image_generate/temp1')
    os.makedirs(out_dir, exist_ok=True)
    # num_epoch - 1 equals the final value of the loop variable the original
    # leaked into this scope, so the filename is unchanged (fake1.png).
    util.save_image(img, os.path.join(out_dir, f'fake{num_epoch - 1}.png'))
    pyplot.show()

# Script entry point: train and dump generated images when run directly.
if __name__ == '__main__':
    main()